{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://githubtocolab.com/CRCTransformers/deepdive-book/blob/main/Chapter-5-T5-sentiment-and-attention.ipynb)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QZbbQmil91Aa"
      },
      "source": [
        "# Description"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PE0-xsUS7vkh"
      },
      "source": [
        "This notebook examines the attention weights of a T5 model fine-tuned for [sentiment span extraction](https://huggingface.co/mrm8488/t5-base-finetuned-span-sentiment-extraction). The model takes a piece of text containing positive or negative (or neutral) sentiment, together with the sentiment to look for, and extracts the subsequence that expresses that sentiment. For example, given the input text `question: negative context: You're a nice person, but your feet stink.`, the model should return the span `your feet stink.`. If the input text replaces \"negative\" with \"positive\", the model should return the span `nice person,`.\n",
        "\n",
        "We want to see whether the attention mechanism highlights the negative sentiment span when the input text asks for negative sentiment, and likewise the positive span when it asks for positive sentiment. We'll use the [BertViz](https://github.com/jessevig/bertviz) library to view the attention weights.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5PAh7tlO9xkF"
      },
      "source": [
        "# Environment setup\n",
        "\n",
        "This notebook pins an older version of Hugging Face Transformers because the T5 model being used doesn't work with the most recent version."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K6ba41uW7ly5"
      },
      "outputs": [],
      "source": [
        "!pip install transformers==4.11.3 sentencepiece==0.1.96 bertviz==1.0.0"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zUXJCWcD7PlS"
      },
      "outputs": [],
      "source": [
        "from transformers import (T5ForConditionalGeneration, \n",
        "                          AutoTokenizer)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qvPK_iU6-AsD"
      },
      "outputs": [],
      "source": [
        "!nvidia-smi"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JVHKdR9W-BO0"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "import torch\n",
        "import transformers\n",
        "import numpy as np\n",
        "\n",
        "print(torch.__version__)\n",
        "print(transformers.__version__)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_a9E-kC5-FwK"
      },
      "outputs": [],
      "source": [
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "peO63ubN-crs"
      },
      "source": [
        "# Model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wyuNI-5Y-fdF"
      },
      "source": [
        "## Loading the model and tokenizer"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bFqP57wT-elr"
      },
      "outputs": [],
      "source": [
        "from transformers import T5ForConditionalGeneration, AutoTokenizer\n",
        "\n",
        "model_name = \"mrm8488/t5-base-finetuned-span-sentiment-extraction\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "model = T5ForConditionalGeneration.from_pretrained(model_name)\n",
        "model = model.to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m_2WCuUC-mbu"
      },
      "source": [
        "## Inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Pb6RleE-lYn"
      },
      "outputs": [],
      "source": [
        "def get_sentiment_span(text, sentiment):\n",
        "    \"\"\"\n",
        "    Given a string of text and sentiment type,\n",
        "    return the substring of input text that contains \n",
        "    the specified sentiment.\n",
        "    \"\"\"\n",
        "    query = f\"question: {sentiment} context: {text}\"\n",
        "    input_ids = tokenizer.encode(\n",
        "        query, \n",
        "        return_tensors=\"pt\", \n",
        "        add_special_tokens=True).to(device)\n",
        "    generated_ids = model.generate(\n",
        "        input_ids=input_ids, \n",
        "        num_beams=1, \n",
        "        max_length=80).squeeze()\n",
        "    predicted_span = tokenizer.decode(\n",
        "        generated_ids, \n",
        "        skip_special_tokens=True, \n",
        "        clean_up_tokenization_spaces=True)\n",
        "    return predicted_span"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5PVJk9Pu_V1K"
      },
      "outputs": [],
      "source": [
        "text = \"You're a nice person, but your feet stink.\"\n",
        "get_sentiment_span(text, \"positive\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cRrbWfKh_lf5"
      },
      "outputs": [],
      "source": [
        "get_sentiment_span(text, \"negative\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lxSxWiL4-S-y"
      },
      "source": [
        "# Visualizing attention weights with BertViz"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Xtz2O7ck52Dh"
      },
      "source": [
        "The T5 model has 12 layers in its encoder and 12 layers in its decoder, which together use three attention mechanisms: \n",
        "\n",
        "1. encoder self-attention (in each encoder layer)\n",
        "2. decoder self-attention (in each decoder layer)\n",
        "3. cross-attention (in each decoder layer, attending to the encoder output)\n",
        "\n",
        "Each attention mechanism has 12 heads per layer, and thus has 144 sets of attention weights, one for each choice of layer and attention head. As mentioned above, we'll use the [BertViz](https://github.com/jessevig/bertviz) library to view the weights of the attention heads."
      ]
    },
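    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The counts above can be checked against the loaded model. The cell below is a minimal sketch that assumes `model` is the `T5ForConditionalGeneration` instance loaded earlier in this notebook. It reads the layer and head counts from `model.config` and multiplies them out for each attention mechanism."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Sketch: read the sizes from the model config instead of hard-coding them.\n",
        "# Assumes `model` was loaded in the \"Loading the model and tokenizer\" section.\n",
        "cfg = model.config\n",
        "\n",
        "# One set of attention weights per (layer, head) pair in each mechanism.\n",
        "counts = {\n",
        "    \"encoder self-attention\": cfg.num_layers * cfg.num_heads,\n",
        "    \"decoder self-attention\": cfg.num_decoder_layers * cfg.num_heads,\n",
        "    \"cross-attention\": cfg.num_decoder_layers * cfg.num_heads,\n",
        "}\n",
        "\n",
        "for name, n in counts.items():\n",
        "    print(f\"{name}: {n} sets of attention weights\")\n",
        "print(f\"total: {sum(counts.values())}\")"
      ]
    },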
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eovMadHhcVHM"
      },
      "outputs": [],
      "source": [
        "from bertviz import head_view\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nLvrYZFNiGT9"
      },
      "outputs": [],
      "source": [
        "def _run_with_attentions(text, sentiment):\n",
        "    \"\"\"Run one forward pass and return (model output, input tokens).\"\"\"\n",
        "    query = f\"question: {sentiment} context: {text}\"\n",
        "    input_ids = tokenizer.encode(\n",
        "        query,\n",
        "        return_tensors=\"pt\",\n",
        "        add_special_tokens=True).to(device)\n",
        "\n",
        "    # Feed the input tokens to the decoder as well, so that every attention\n",
        "    # matrix is square and can be labeled with the same token list. This is\n",
        "    # a visualization convenience: the decoder sees the input sequence\n",
        "    # rather than the generated span.\n",
        "    with torch.no_grad():\n",
        "        output = model(\n",
        "            input_ids=input_ids,\n",
        "            decoder_input_ids=input_ids,\n",
        "            output_attentions=True,\n",
        "            return_dict=True)\n",
        "\n",
        "    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])\n",
        "    return output, tokens\n",
        "\n",
        "\n",
        "def view_cross_attn_heads(text, sentiment, layer=None, heads=None):\n",
        "    \"\"\"Show the decoder-to-encoder (cross) attention heads.\"\"\"\n",
        "    output, tokens = _run_with_attentions(text, sentiment)\n",
        "    head_view(output.cross_attentions, tokens, layer=layer, heads=heads)\n",
        "\n",
        "\n",
        "def view_decoder_attn_heads(text, sentiment, layer=None, heads=None):\n",
        "    \"\"\"Show the decoder self-attention heads.\"\"\"\n",
        "    output, tokens = _run_with_attentions(text, sentiment)\n",
        "    head_view(output.decoder_attentions, tokens, layer=layer, heads=heads)\n",
        "\n",
        "\n",
        "def view_encoder_attn_heads(text, sentiment, layer=None, heads=None):\n",
        "    \"\"\"Show the encoder self-attention heads.\"\"\"\n",
        "    output, tokens = _run_with_attentions(text, sentiment)\n",
        "    head_view(output.encoder_attentions, tokens, layer=layer, heads=heads)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z5H1e09V6ux5"
      },
      "source": [
        "To start, we want to see whether the weights of any attention head show the word \"positive\" in the input text attending to any of the tokens in the extracted subsequence."
      ]
    },
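    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One way to narrow down which heads are worth inspecting is to score them numerically before opening the visualization. The cell below is a rough sketch and not part of BertViz: `rank_heads` is a hypothetical helper that sums, for each layer and head, the attention that one token pays to a set of other tokens, and returns the highest-scoring heads. It assumes each entry of the attention tuple has shape `(batch, heads, seq, seq)`, and that the tokenizer keeps the words \"positive\", \"nice\", and \"person\" as single pieces; if it splits them, the index lookups need adjusting. Layer and head numbers are zero-based, matching the `layer` and `heads` arguments used in this notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "def rank_heads(attentions, source_idx, target_idxs, top_k=5):\n",
        "    \"\"\"Hypothetical helper: rank (layer, head) pairs by the attention\n",
        "    that the token at source_idx pays to the tokens at target_idxs.\"\"\"\n",
        "    scores = []\n",
        "    for layer, layer_attn in enumerate(attentions):\n",
        "        # First batch item, row of the source token: shape (heads, seq)\n",
        "        from_source = layer_attn[0, :, source_idx]\n",
        "        # Sum the weights that land on the target tokens: shape (heads,)\n",
        "        mass_per_head = from_source[:, target_idxs].sum(dim=-1)\n",
        "        for head, mass in enumerate(mass_per_head.tolist()):\n",
        "            scores.append((mass, layer, head))\n",
        "    return sorted(scores, reverse=True)[:top_k]\n",
        "\n",
        "\n",
        "example_text = \"You're a nice person, but your feet stink.\"\n",
        "query = f\"question: positive context: {example_text}\"\n",
        "input_ids = tokenizer.encode(query, return_tensors=\"pt\").to(device)\n",
        "\n",
        "with torch.no_grad():\n",
        "    output = model(\n",
        "        input_ids=input_ids,\n",
        "        decoder_input_ids=input_ids,\n",
        "        output_attentions=True,\n",
        "        return_dict=True)\n",
        "\n",
        "tokens = tokenizer.convert_ids_to_tokens(input_ids[0])\n",
        "\n",
        "# Assumption: each word survives tokenization as one piece.\n",
        "# next() raises StopIteration if no token contains \"positive\".\n",
        "source_idx = next(i for i, tok in enumerate(tokens) if \"positive\" in tok)\n",
        "target_idxs = [i for i, tok in enumerate(tokens)\n",
        "               if \"nice\" in tok or \"person\" in tok]\n",
        "\n",
        "for mass, layer, head in rank_heads(\n",
        "        output.encoder_attentions, source_idx, target_idxs):\n",
        "    print(f\"layer {layer:2d}  head {head:2d}  attention mass {mass:.3f}\")"
      ]
    },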
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AzFhwb7U7ltA"
      },
      "outputs": [],
      "source": [
        "text = \"You're a nice person, but your feet stink.\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PuitK10YmCi1"
      },
      "outputs": [],
      "source": [
        "view_encoder_attn_heads(text, \"positive\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lq2ukYZlig4C"
      },
      "outputs": [],
      "source": [
        "view_decoder_attn_heads(text, \"positive\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ddRIkeFIib5Z"
      },
      "outputs": [],
      "source": [
        "view_cross_attn_heads(text, \"positive\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4HqvU7t8mSOn"
      },
      "outputs": [],
      "source": [
        "view_encoder_attn_heads(text, \"positive\", layer=6, heads=[11])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PycYC5CInUTV"
      },
      "outputs": [],
      "source": [
        "view_encoder_attn_heads(text, \"negative\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5B5KjdlksFeL"
      },
      "outputs": [],
      "source": [
        "view_cross_attn_heads(text, \"negative\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JVxm6_Q88nSc"
      },
      "source": [
        "Now we'll look at a second example that contains both positive *and* negative sentiment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t__kJphsLKBA"
      },
      "outputs": [],
      "source": [
        "text = \"It was the best of times, it was the worst of times.\""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o6QbFh-N7_xI"
      },
      "outputs": [],
      "source": [
        "get_sentiment_span(text, \"positive\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "71xMFlvpMDNX"
      },
      "outputs": [],
      "source": [
        "get_sentiment_span(text, \"negative\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p8P0u10a8GEj"
      },
      "outputs": [],
      "source": [
        "view_encoder_attn_heads(text, \"positive\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PXt6nY6E9PfA"
      },
      "outputs": [],
      "source": [
        "view_encoder_attn_heads(text, \"negative\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xFSUuAqm9FRU"
      },
      "outputs": [],
      "source": [
        "view_decoder_attn_heads(text, \"positive\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-S6SoAwG9kCh"
      },
      "outputs": [],
      "source": [
        "view_decoder_attn_heads(text, \"negative\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x2Df6gvw9G8Y"
      },
      "outputs": [],
      "source": [
        "view_cross_attn_heads(text, \"positive\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "background_save": true
        },
        "id": "v1IwH9PS9V5h"
      },
      "outputs": [],
      "source": [
        "view_cross_attn_heads(text, \"negative\")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "5PAh7tlO9xkF"
      ],
      "name": "Chapter-5-T5-sentiment-and-attention.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
